# ------------------------------------------------------------------------------
# Build comorbidity flags and hazard-ratio weights for the full cohort
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=10000]" bash 09_build_comorbidities.sh
# ------------------------------------------------------------------------------

# Setup ------------------------------------------------------------------------
library(yaml) # read_yaml absolute filepaths
library(data.table)
library(here) # here() relative filepaths
library(testit) # assert function
library(tidyverse)
library(glue)

# Project utility module; its helpers are called below as u$<name>().
u <- modules::use(here("lib", "util.R"))
# Scratch output directory for this script (comorbidities_df.rds is saved here).
temp <- file.path(
  here("code"), "03_analysis", "02_prep_and_summarize_cohort", "temp"
)

# Load Data --------------------------------------------------------------------
message("Loading data...")
# Absolute filepaths for cohort/feature/demographics inputs.
paths <- read_yaml(here("lib", "filepaths.yml"))
# Crosswalks (feature code -> comorbidity) and per-comorbidity weights.
zocat_comorbidities <- here("lib", "zocats_to_comorbidities.csv") %>%
  read_csv()
ccs_comorbidities <- here("lib", "ccs_to_comorbidities.csv") %>%
  read_csv()
comorbidity_wts_raw <- here("lib", "comorbidity_weight.csv") %>%
  read_csv()
# Only the columns needed to pick sex-specific weights are kept.
demographics <- paths$analysis$demographics %>%
  readRDS() %>%
  select(ed_enc_id, sex_female, age_at_admit)

# Clean xwalk ------------------------------------------------------------------
message("Cleaning comorbidities ")
# Build a single crosswalk: feature column name -> comorbidity.
# ZO categories map to count features named
# dia_t2yall_zc_cat_name_count_<zocat>. CCS rows are assumed to already hold
# feature column names; that is checked against the feature matrix later.
zocats <- zocat_comorbidities %>%
  mutate(feature = paste0("dia_t2yall_zc_cat_name_count_", zocat)) %>%
  select(feature, comorbidity)
# rename() (not data.table::setnames) so ccs_comorbidities is not modified by
# reference; select() keeps the column set aligned with `zocats`.
ccs <- ccs_comorbidities %>%
  rename(feature = ccs) %>%
  mutate(feature = as.character(feature)) %>%
  select(feature, comorbidity)
comorbidities_xwalk <- bind_rows(zocats, ccs)

feature_names <- comorbidities_xwalk$feature
comorbidity_names <- unique(comorbidities_xwalk$comorbidity)
# Every constructed comorbidity must have a weight; show offenders first.
unmatched <- setdiff(comorbidity_names, comorbidity_wts_raw$comorbidity)
print(unmatched)
assert(
  "All comorbidities present in weight xwalk",
  length(unmatched) == 0
)

# Named lookups (comorbidity name -> hazard-ratio weight), one per sex.
comorbidity_wts_women <- setNames(
  comorbidity_wts_raw$hr_women, comorbidity_wts_raw$comorbidity
)
comorbidity_wts_men <- setNames(
  comorbidity_wts_raw$hr_men, comorbidity_wts_raw$comorbidity
)

# Zero-row template fixing the output column order: ids/demographics, then a
# flag column and a weight_<name> column for each comorbidity.
comorbidities_df <- tibble(
  ed_enc_id = double(), age_at_admit = double(), sex_female = double()
)
for (x in comorbidity_names) {
  print(x)
  weight_name <- glue("weight_{x}")
  comorbidities_df[[x]] <- double()
  # c() drops the "glue" class so the tibble sees a plain character name.
  comorbidities_df[[c(weight_name)]] <- double()
}

# Unused comorbidities ---------------------------------------------------------
# Comorbidities that have a weight but are never built from any feature.
all_comors <- comorbidity_wts_raw$comorbidity
unused_comors <- all_comors[!(all_comors %in% comorbidity_names)]
if (length(unused_comors) > 0) {
  message(
    "WARNING: the following comorbidities are not used in hazard calculations"
  )
  walk(unused_comors, print)
}

# Construct comorbidities -----------------------------------------------------
message("Constructing comorbidities variables...")
# `split` is never referenced directly: paths$features entries go through
# glue(), so it is presumably interpolated into a `{split}` placeholder in the
# filepath template. TODO confirm in filepaths.yml before renaming/removing.
split <- "random"
for (sample_split in c("val", "test", "train")) {
  message("Getting comorbidities for ", sample_split, "...")
  features_split <- readRDS(glue(paths$features[[sample_split]])) %>%
    as.matrix()

  assert(
    "Comorbidities features in features df",
    all(feature_names %in% colnames(features_split))
  )

  # Keep only crosswalk features, as presence indicators. drop = FALSE keeps a
  # matrix even when the split has a single row.
  features_split <- features_split[, unlist(feature_names), drop = FALSE] > 0
  ids_split <- readRDS(paths$cohort[[sample_split]]) %>%
    u$safe_left_join(demographics) %>%
    select(ed_enc_id, age_at_admit, sex_female) %>%
    setDT()

  for (x in comorbidity_names) {
    comor_feats <- comorbidities_xwalk %>%
      filter(comorbidity == x) %>%
      .[["feature"]] %>%
      unlist()
    # drop = FALSE: one row and/or one feature still yields a matrix, so
    # rowSums() returns exactly one value per encounter.
    x_comorbidity <- features_split[, comor_feats, drop = FALSE]

    weight_name <- glue("weight_{x}")

    # Flag: any mapped feature present for this row.
    ids_split[[x]] <- rowSums(x_comorbidity) > 0
    # Weight is the sex-specific hazard ratio when flagged, else 1
    # (hr^FALSE == 1). Rows with sex_female not in {0, 1} (incl. NA) get NA.
    ids_split[[weight_name]] <- case_when(
      ids_split$sex_female == 1 ~ comorbidity_wts_women[x]^ids_split[[x]],
      ids_split$sex_female == 0 ~ comorbidity_wts_men[x]^ids_split[[x]]
    )
  }
  comorbidities_df <- rbind(comorbidities_df, ids_split)
}

# Print means ------------------------------------------------------------------
message("Getting comorbidities rates in full population...")
# Mean of each comorbidity flag over all rows (val, test and train pooled).
walk(comorbidity_names, function(comorbid) {
  message(comorbid, " rate: ", mean(comorbidities_df[[comorbid]]))
})

# Save -------------------------------------------------------------------------
save_fp <- file.path(temp, "comorbidities_df.rds")
# Ensure the scratch directory exists so a long batch run does not fail at the
# final write. This is a no-op when the directory is already there.
dir.create(temp, recursive = TRUE, showWarnings = FALSE)
message("Saving comorbidities to ", save_fp, "...")
write_rds(comorbidities_df, save_fp)

message("Done.")
